#!/usr/bin/env python3
"""
Collaborative Filtering for Movie Recommendations
Item-Item approach with Cosine Similarity
Optimized multi-threaded version with caching
"""

import sys
import math
import time
import pickle
import os
from multiprocessing import Pool, cpu_count
from itertools import combinations
from hashlib import md5


def get_cache_filename(ratings_file, similarity_threshold):
    """
    OBJ: Generate cache filename based on ratings file and threshold
    str, float -> str

    The name embeds the first 8 hex digits of the MD5 of the ratings file's
    bytes plus the threshold (with '.' replaced by '_'), so a different file
    content or a different threshold maps to a different cache file.

    Raises OSError if the ratings file cannot be opened or read.
    """
    # Hash the file incrementally so that a large ratings file is never held
    # in memory in full just to derive the cache key.
    digest = md5()
    with open(ratings_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    file_hash = digest.hexdigest()[:8]

    threshold_str = str(similarity_threshold).replace('.', '_')
    return f".cache_similarities_{file_hash}_t{threshold_str}.pkl"


def load_cache(cache_file):
    """
    OBJ: Load similarities from cache file
    str -> dict or None

    Returns the unpickled similarities dict, or None when the file does not
    exist, cannot be read/unpickled, or does not contain a dict. Problems are
    reported on stdout and never raised, so a bad cache only costs a
    recomputation.
    """
    if not os.path.exists(cache_file):
        return None

    print(f"Found cache file: {cache_file}")
    try:
        with open(cache_file, 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code embedded in
            # the file. Only point this at cache files this program wrote.
            similarities = pickle.load(f)
    except Exception as e:
        # Unpickling can fail with many exception types (truncated file,
        # corrupt data, ...); all of them simply mean "no usable cache".
        print(f"Error loading cache: {e}")
        return None

    # A readable pickle of the wrong shape (e.g. a stale or foreign file)
    # must not be handed to callers that index it as a dict.
    if not isinstance(similarities, dict):
        print(f"Error loading cache: expected dict, got {type(similarities).__name__}")
        return None

    print(f"Loaded {len(similarities)} movies from cache")
    return similarities


def save_cache(cache_file, similarities):
    """
    OBJ: Save similarities to cache file
    str, dict -> None

    The data is pickled to a temporary file next to cache_file and then
    moved into place with os.replace, so a failure while writing never
    leaves a truncated cache behind or destroys a previous good one.
    Failures are reported on stdout and swallowed (caching is best-effort).
    """
    # Same directory as the target so that os.replace is a plain rename.
    tmp_file = f"{cache_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_file, 'wb') as f:
            pickle.dump(similarities, f)
        os.replace(tmp_file, cache_file)
        print(f"Saved similarities to cache: {cache_file}")
    except Exception as e:
        print(f"Error saving cache: {e}")
        # Do not leave a partial temporary file lying around.
        try:
            os.remove(tmp_file)
        except OSError:
            pass


def load_ratings(filename):
    """
    OBJ: Load ratings from CSV file
    str -> dict, dict

    The first line is treated as a header and skipped. Every following
    non-blank line is split on ',' and its first three fields are read as
    user id (int), movie id (int) and rating (float); extra fields are
    ignored.

    Returns (users, movies) where
        users[user_id][movie_id]  = rating
        movies[movie_id][user_id] = rating
    An empty file yields two empty dicts.

    Raises ValueError / IndexError on a malformed data line.
    """
    users = {}
    movies = {}

    with open(filename, 'r') as f:
        # Default of None: an empty file has no header line to skip, and a
        # bare next(f) would raise StopIteration there.
        next(f, None)

        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split(',')
            user_id = int(parts[0])
            movie_id = int(parts[1])
            rating = float(parts[2])

            # Maintain both indexes: by user and by movie.
            users.setdefault(user_id, {})[movie_id] = rating
            movies.setdefault(movie_id, {})[user_id] = rating

    return users, movies


def cosine_similarity_optimized(movie1_ratings, movie2_ratings):
    """
    OBJ: Cosine similarity of two movies over the users who rated both
    dict, dict -> float

    Each argument maps user id -> rating. Only users present in both dicts
    contribute to the dot product and to the two magnitudes. Returns 0.0
    when there is no such user or when either restricted vector has zero
    magnitude.
    """
    shared_users = set(movie1_ratings.keys()) & set(movie2_ratings.keys())
    if len(shared_users) == 0:
        return 0.0

    # One (rating for movie 1, rating for movie 2) pair per shared user
    paired = [(movie1_ratings[user], movie2_ratings[user]) for user in shared_users]

    norm1 = math.sqrt(sum(a * a for a, _ in paired))
    norm2 = math.sqrt(sum(b * b for _, b in paired))
    if norm1 == 0 or norm2 == 0:
        return 0.0

    numerator = sum(a * b for a, b in paired)
    return numerator / (norm1 * norm2)


def compute_similarities_batch(args):
    """
    OBJ: Worker function - score one batch of movie pairs
    tuple -> list

    args is (pairs, movies, similarity_threshold): pairs is an iterable of
    (movie_id, movie_id) tuples and movies maps movie id -> {user: rating}.
    Returns the (movie1, movie2, similarity) triples whose similarity is at
    least the threshold, in the order the pairs were given.
    """
    pairs, movies, similarity_threshold = args

    # Lazily score each pair, then keep only those reaching the threshold
    scored = (
        (first, second, cosine_similarity_optimized(movies[first], movies[second]))
        for first, second in pairs
    )
    return [triple for triple in scored if triple[2] >= similarity_threshold]


def compute_movie_similarities(movies, similarity_threshold=0.0, num_threads=None, cache_file=None):
    """
    OBJ: Build the movie -> similar-movies table, using the cache if possible
    dict, float, int, str -> dict

    movies maps movie id -> {user id: rating}. When cache_file is given and
    load_cache returns a result, that result is returned as is. Otherwise
    every unordered pair of movies is scored in a multiprocessing Pool of
    num_threads workers (default: cpu_count()), pairs scoring below
    similarity_threshold are dropped, and the table is written back through
    save_cache when cache_file is given.

    Returns {movie_id: [(other_movie_id, similarity), ...]} with each list
    sorted by similarity, highest first. Movies with no retained pair do not
    appear as keys.
    """
    # A usable cache short-circuits the whole computation
    cached = load_cache(cache_file) if cache_file else None
    if cached is not None:
        print("Using cached similarities (skipping computation)")
        return cached

    num_threads = cpu_count() if num_threads is None else num_threads

    movie_ids = list(movies.keys())
    n_movies = len(movie_ids)

    print(f"Computing similarities for {n_movies} movies using {num_threads} threads...")
    print(f"Total comparisons: {n_movies * (n_movies - 1) // 2}")

    # Every unordered pair of movies
    all_pairs = list(combinations(movie_ids, 2))
    total_pairs = len(all_pairs)

    # Aim for roughly four batches per worker so the work is spread out.
    # NOTE(review): each batch tuple carries the whole `movies` dict, which
    # presumably is serialized once per batch when sent to a worker - confirm.
    batch_size = max(1, total_pairs // (num_threads * 4))
    batches = [
        (all_pairs[start:start + batch_size], movies, similarity_threshold)
        for start in range(0, total_pairs, batch_size)
    ]

    print(f"Created {len(batches)} batches for load balancing")

    with Pool(num_threads) as pool:
        batch_results = pool.map(compute_similarities_batch, batches)

    # Record every retained pair under both of its movies
    similarities = {}
    for batch_result in batch_results:
        for movie1, movie2, sim in batch_result:
            similarities.setdefault(movie1, []).append((movie2, sim))
            similarities.setdefault(movie2, []).append((movie1, sim))

    # Most similar neighbours first
    for neighbours in similarities.values():
        neighbours.sort(key=lambda entry: entry[1], reverse=True)

    print("Similarity computation complete!")

    if cache_file:
        save_cache(cache_file, similarities)

    return similarities


def predict_rating(user_id, movie_id, users, similarities, k=10):
    """
    OBJ: Predict a user's rating for a movie (Item-Item collaborative filtering)
    int, int, dict, dict, int -> float or None

    Looks at the first k entries of similarities[movie_id] and, among those
    the user has rated, returns the similarity-weighted average of the
    user's ratings (weights normalised by the sum of absolute similarities).

    Returns None when the movie has no entry in similarities or when the
    user rated none of those k neighbours. Raises KeyError if user_id is not
    in users (and the movie has an entry).
    """
    if movie_id not in similarities:
        return None

    rated = users[user_id]

    numerator = 0.0
    denominator = 0.0
    for neighbour, weight in similarities[movie_id][:k]:
        # Only neighbours the user actually rated can contribute
        if neighbour not in rated:
            continue
        numerator += weight * rated[neighbour]
        denominator += abs(weight)

    return numerator / denominator if denominator != 0 else None


def recommend_for_user(user_id, users, movies, similarities, k=10, top_n=5):
    """
    OBJ: Top-N movie recommendations for one user
    int, dict, dict, dict, int, int -> list

    Predicts a rating (via predict_rating with the given k) for every movie
    in movies that the user has not rated, drops movies with no prediction,
    and returns up to top_n (movie_id, predicted_rating) tuples ordered by
    predicted rating, highest first.

    An unknown user_id prints a message and yields an empty list.
    """
    if user_id not in users:
        print(f"User {user_id} not found in dataset")
        return []

    already_rated = set(users[user_id].keys())
    candidates = set(movies.keys()) - already_rated

    # (movie, prediction) for each candidate; prediction may be None
    scored = (
        (candidate, predict_rating(user_id, candidate, users, similarities, k))
        for candidate in candidates
    )
    ranked = sorted(
        (entry for entry in scored if entry[1] is not None),
        key=lambda entry: entry[1],
        reverse=True,
    )
    return ranked[:top_n]


def recommend_batch(args):
    """
    OBJ: Worker function - best single recommendation for each user in a batch
    tuple -> list

    args is (user_batch, users, movies, similarities, k). For every user id
    in user_batch the top recommendation is requested from
    recommend_for_user; users for whom nothing is returned are skipped.
    Returns a list of (user_id, movie_id, predicted_rating) tuples.
    """
    user_batch, users, movies, similarities, k = args

    best_picks = []
    for uid in user_batch:
        top = recommend_for_user(uid, users, movies, similarities, k, top_n=1)
        if not top:
            continue
        best_movie, best_score = top[0]
        best_picks.append((uid, best_movie, best_score))

    return best_picks


def generate_recommendations(users, movies, similarities, num_threads=None, k=10):
    """
    OBJ: Best recommendation for every user, computed in a process pool
    dict, dict, dict, int, int -> list

    The user ids are cut into batches (about two per worker), each batch is
    handed to recommend_batch in a multiprocessing Pool of num_threads
    workers (default: cpu_count()), and the per-batch results are
    concatenated in batch order.

    Returns a list of (user_id, movie_id, predicted_rating) tuples.
    """
    num_threads = cpu_count() if num_threads is None else num_threads

    user_ids = list(users.keys())
    print(f"Generating recommendations for {len(user_ids)} users using {num_threads} threads...")

    # Roughly two batches per worker
    batch_size = max(1, len(user_ids) // (num_threads * 2))
    batches = [
        (user_ids[start:start + batch_size], users, movies, similarities, k)
        for start in range(0, len(user_ids), batch_size)
    ]

    with Pool(num_threads) as pool:
        batch_results = pool.map(recommend_batch, batches)

    # Flatten the per-batch lists into one list
    recommendations = [rec for batch_result in batch_results for rec in batch_result]

    print("Recommendations complete!")
    return recommendations


# Script entry point: load ratings, build (or load cached) similarities,
# print recommendations, then print a timing summary.
if __name__ == '__main__':
    if len(sys.argv) < 2:
        print("Usage: ./collab_filter.py <ratings_file> [similarity_threshold] [user_id] [num_threads] [--no-cache]")
        print("*If user_id is provided, shows recommendations only for that user")
        print("*num_threads: number of threads to use (default: all CPU cores)")
        print("*--no-cache: disable similarity caching")
        sys.exit(1)

    ratings_file = sys.argv[1]
    # Each optional positional falls back to its default when its slot holds
    # the --no-cache flag. The threshold slot needs the same guard as the
    # other two, otherwise "<ratings_file> --no-cache" dies in float().
    similarity_threshold = float(sys.argv[2]) if len(sys.argv) > 2 and sys.argv[2] != '--no-cache' else 0.0
    single_user_id = int(sys.argv[3]) if len(sys.argv) > 3 and sys.argv[3] != '--no-cache' else None
    num_threads = int(sys.argv[4]) if len(sys.argv) > 4 and sys.argv[4] != '--no-cache' else None
    use_cache = '--no-cache' not in sys.argv

    # Generate cache filename (None disables caching downstream)
    cache_file = get_cache_filename(ratings_file, similarity_threshold) if use_cache else None

    # Start total timing
    total_start = time.time()

    # Load data
    print("Loading ratings...")
    start_time = time.time()
    users, movies = load_ratings(ratings_file)
    load_time = time.time() - start_time
    print(f"Loaded {len(users)} users and {len(movies)} movies in {load_time:.2f} seconds")

    # Compute similarities (with caching)
    print("\nComputing similarities...")
    start_time = time.time()
    similarities = compute_movie_similarities(movies, similarity_threshold, num_threads, cache_file)
    sim_time = time.time() - start_time
    print(f"Similarity computation took {sim_time:.2f} seconds")

    # Generate recommendations
    print("\nGenerating recommendations...")
    start_time = time.time()
    if single_user_id is not None:
        # Single user mode: top recommendations for that user only
        print(f"\nTop recommendations for user {single_user_id}:")
        recommendations = recommend_for_user(single_user_id, users, movies, similarities)
        for movie_id, rating in recommendations:
            print(f"{single_user_id} {movie_id} {rating:.1f}")
    else:
        # All users mode: one best recommendation per user, in parallel
        recommendations = generate_recommendations(users, movies, similarities, num_threads)
        for user_id, movie_id, rating in recommendations:
            print(f"{user_id} {movie_id} {rating:.1f}")
    rec_time = time.time() - start_time
    print(f"\nRecommendation generation took {rec_time:.2f} seconds")

    total_time = time.time() - total_start
    # A coarse clock can report zero elapsed time on a tiny run; dividing by
    # infinity then prints 0.0% instead of raising ZeroDivisionError.
    pct_base = total_time if total_time > 0 else float('inf')
    print(f"\n{'='*50}")
    print(f"TOTAL EXECUTION TIME: {total_time:.2f} seconds")
    print(f"  - Loading: {load_time:.2f}s ({load_time/pct_base*100:.1f}%)")
    print(f"  - Similarities: {sim_time:.2f}s ({sim_time/pct_base*100:.1f}%)")
    print(f"  - Recommendations: {rec_time:.2f}s ({rec_time/pct_base*100:.1f}%)")
    print(f"{'='*50}")

    if use_cache and cache_file:
        print(f"\nCache file: {cache_file}")
        print("Next run will use cached similarities for instant results!")